# -*- coding: utf-8 -*-
"""
Created on Tue Mar 12 18:58:35 2024

@author: Administrator
"""

import state_extraction
import numpy as np
import sys
import os

import keras
from keras.datasets import mnist
from keras.datasets import cifar10
from keras import models
from keras import layers
from keras import optimizers
import matplotlib.pyplot as plt
import tools
import time
import scipy.io as scio
import cv2

from PIL import Image
from keras.preprocessing.image import img_to_array
from keras.preprocessing.image import load_img
from keras.applications.resnet50 import preprocess_input as preprocess_input_resnet50 # 注意！使用那个网络就必须import对应网络中的preprocess_input
from keras.applications.mobilenet import preprocess_input as preprocess_input_mobilenet # 注意！使用那个网络就必须import对应网络中的preprocess_input
from keras.applications.xception import preprocess_input as preprocess_input_xception # 注意！使用那个网络就必须import对应网络中的preprocess_input
from keras.applications.efficientnet import preprocess_input as preprocess_input_efficientnet # 注意！使用那个网络就必须import对应网络中的preprocess_input
from keras.applications.imagenet_utils import decode_predictions
# from keras.applications.vgg19 import VGG19 # 模型庞大不可使用
# from keras.applications.vgg16 import VGG16 # 模型庞大不可使用
from keras.applications.resnet50  import  ResNet50 # 各模型准确率参考https://keras.io/api/applications/
from keras.applications.mobilenet  import  MobileNet
from keras.applications.xception import Xception
from keras.applications.efficientnet import EfficientNetB0


def normalize_sequences(sequences, dimension=784):
    """Flatten each array in *sequences* and scale its values by 1/255.

    Args:
        sequences: sized sequence of NumPy arrays; each is flattened with
            ``.flatten()``.
        dimension: number of columns of the output (784 = 28 * 28 by default).

    Returns:
        float64 array of shape ``(len(sequences), dimension)`` whose row ``k``
        holds ``sequences[k].flatten() / 255``.
    """
    normalized = np.zeros((len(sequences), dimension))
    # Iterating a 2-D array yields writable row views, so filling ``row``
    # fills the matching row of ``normalized``.
    for row, seq in zip(normalized, sequences):
        row[:] = seq.flatten() / 255
    return normalized


def vectorize_labels(labels, label_num):
    """One-hot encode class labels.

    Args:
        labels: sized iterable of labels; each is converted with ``int()``,
            so float labels such as ``3.0`` are accepted.
        label_num: number of classes, i.e. the width of each one-hot row.

    Returns:
        float64 array of shape ``(len(labels), label_num)`` with a single 1.0
        per row, at column ``int(label)``.

    Raises:
        IndexError: if a label falls outside ``[-label_num, label_num)``.
    """
    column_idx = np.array([int(label) for label in labels], dtype=int)
    one_hot = np.zeros((len(labels), label_num))
    # Pair every row index with its label column and set those cells at once.
    one_hot[np.arange(len(labels)), column_idx] = 1.0
    return one_hot

def closest(lst, K):
    """Return the element of *lst* nearest to *K*.

    Ties go to the earliest element because ``np.argmin`` returns the first
    minimum.

    Args:
        lst: sequence of numbers (converted with ``np.asarray``).
        K: target value.

    Returns:
        The chosen element of *lst*.

    Raises:
        ValueError: if *lst* is empty.
    """
    candidates = np.asarray(lst)
    distances = np.abs(candidates - K)
    return candidates[np.argmin(distances)]

def conductance_mapping(weights,conductances,dimension=4):
    """Snap every weight to the nearest signed, rescaled device conductance level.

    The level set is built from *conductances*: the values are divided by
    ``conductances[-1]``, multiplied by ``max(|weights|)`` and mirrored to
    negative values, giving ``[-c_n, ..., -c_1, c_1, ..., c_n]``. Each weight
    is replaced by its nearest level; on an exact tie the level that comes
    first in that list wins (the same rule as ``closest``).

    Args:
        weights: NumPy array of any shape (e.g. a Dense or Conv2D kernel).
        conductances: 1-D array of device conductance values; the last entry
            is the normalisation reference.
        dimension: kept for backward compatibility only. It used to select a
            rank-specific loop (the rank-1 branch always crashed); the mapping
            is now element-wise for any rank, so the value is not used.

    Returns:
        A new array with the same shape and dtype as *weights*; *weights* is
        not modified. An empty *weights* yields an empty copy.
    """
    weights = np.asarray(weights)
    loc_weights = weights * 1  # independent copy that keeps the input dtype
    if weights.size == 0:
        return loc_weights
    conductances = np.asarray(conductances)
    weight_max = np.abs(weights).max()
    scaled = conductances / conductances[-1] * weight_max
    # Same order as before: reversed negated levels, then the positive ones.
    levels = np.concatenate((-scaled[::-1], scaled))
    flat_w = weights.reshape(-1)
    if levels.size > 1 and np.all(levels[1:] >= levels[:-1]):
        # Sorted levels: binary-search the two neighbours of every weight
        # (O(N log L) instead of comparing each weight with every level).
        idx = np.clip(np.searchsorted(levels, flat_w, side='left'), 1, levels.size - 1)
        lower = levels[idx - 1]
        upper = levels[idx]
        # '<=' keeps the lower (earlier-listed) level on an exact tie,
        # matching the first-minimum rule of argmin.
        nearest = np.where(np.abs(lower - flat_w) <= np.abs(upper - flat_w), lower, upper)
    else:
        # Unsorted levels: brute-force argmin, chunked so the
        # (chunk x levels) distance matrix stays around 4M entries.
        nearest = np.empty(flat_w.shape, dtype=levels.dtype)
        chunk = max(1, (1 << 22) // levels.size)
        for start in range(0, flat_w.size, chunk):
            part = flat_w[start:start + chunk]
            dist = np.abs(levels[np.newaxis, :] - part[:, np.newaxis])
            nearest[start:start + chunk] = levels[dist.argmin(axis=1)]
    # Element-wise assignment casts back to the dtype of ``weights * 1``.
    loc_weights[...] = nearest.reshape(weights.shape)
    return loc_weights

def build_convnet_model(label_num, imageDimension, kernerNum = 128, lr=1e-3):
    """Build and compile a three-stage convolutional classifier.

    Architecture: three blocks of Conv2D(3x3, relu, 'same' padding) followed
    by MaxPooling2D(2x2), with kernerNum, 2*kernerNum and 4*kernerNum
    filters; then Flatten -> Dense(100, relu) -> Dense(label_num, softmax).

    Args:
        label_num: number of output classes.
        imageDimension: input shape ``(height, width, channels)``.
        kernerNum: number of filters of the first convolution.
        lr: RMSprop learning rate.

    Returns:
        A ``Sequential`` model compiled with RMSprop, categorical
        cross-entropy loss and the accuracy metric.
    """
    model = models.Sequential()
    # Only the first layer needs input_shape; Keras infers the shapes of the
    # later layers. (The float-valued input_shape hints previously passed to
    # the 2nd and 3rd convolutions were unused and have been removed.)
    model.add(layers.Conv2D(kernerNum, (3, 3), activation='relu', input_shape=imageDimension, padding='same'))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(kernerNum * 2, (3, 3), activation='relu', padding='same'))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(kernerNum * 4, (3, 3), activation='relu', padding='same'))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Flatten())
    model.add(layers.Dense(100, activation='relu'))
    model.add(layers.Dense(label_num, activation='softmax'))  # one output unit per class
    model.compile(optimizer=optimizers.RMSprop(learning_rate=lr), loss='categorical_crossentropy', metrics=['accuracy'])
    return model

def build_model_net(label_num, input_size, lr=1e-3):
    """Build and compile a fully connected classifier with one hidden layer.

    Architecture: Dense(300, relu) on a flat input vector of length
    *input_size*, then Dense(label_num, softmax).

    Args:
        label_num: number of output classes.
        input_size: length of each flat input vector.
        lr: RMSprop learning rate.

    Returns:
        A ``Sequential`` model compiled with RMSprop, categorical
        cross-entropy loss and the accuracy metric.
    """
    net = models.Sequential()
    hidden = layers.Dense(300, activation='relu', input_shape=(input_size,))
    classifier = layers.Dense(label_num, activation='softmax')  # one output unit per class
    for layer in (hidden, classifier):
        net.add(layer)
    net.compile(
        optimizer=optimizers.RMSprop(learning_rate=lr),
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )
    return net

def net_model_training_precision(model, train_dataset, test_dataset, label_num, batch_size=128, epochs=30, label_start=0, precision=8, plotTitle=None, savefile=None):
    """Train a flat-input model one epoch at a time, snapping weights to device conductances.

    Each row of *train_dataset* / *test_dataset* is a feature vector followed
    by the class label in its last column. The test set is evaluated once
    before training and again after every epoch. When ``precision != 0``,
    after each epoch the kernel of every layer whose name contains 'conv2d'
    or 'dense' is replaced via ``conductance_mapping``, using the first
    column of ``'<precision> Bit_params.txt'``; biases are kept as trained.

    Args:
        model: compiled Keras model whose ``evaluate`` returns
            ``[loss, accuracy]``.
        train_dataset: 2-D array laid out as described above.
        test_dataset: 2-D array laid out as described above.
        label_num: number of classes.
        batch_size: mini-batch size for ``model.fit``.
        epochs: number of training epochs.
        label_start: value of the first class label; it is subtracted from
            every label before one-hot encoding (e.g. 0 or 1).
        precision: bit precision selecting the conductance file; 0 disables
            the weight mapping.
        plotTitle: optional title of the accuracy/loss figure.
        savefile: optional path prefix; when given, the accuracy and loss
            curves, the confusion matrix and the figure are written to disk.

    Returns:
        ``(history, result_matrix)``. *history* maps 'val_accuracy' and
        'val_loss' to lists of ``epochs + 1`` values (index 0 = before
        training). *result_matrix* is the confusion matrix from
        ``generate_result_matrix``.
    """
    loc_train_data, loc_test_data = train_dataset[:, :-1], test_dataset[:, :-1]
    # Shift labels so the first class lands in column 0. Previously only
    # label_start 0 or 1 worked; any other value raised UnboundLocalError.
    loc_train_labels = vectorize_labels(train_dataset[:, -1] - label_start, label_num)
    loc_test_labels = vectorize_labels(test_dataset[:, -1] - label_start, label_num)
    # The conductance table is the same for every epoch: read it once, up front.
    loc_conductances = None
    if precision != 0:
        loc_conductances = np.loadtxt(str(precision) + ' Bit_params.txt')[:, 0]
    # model fitting
    loc_temp = model.evaluate(loc_test_data, loc_test_labels)
    loc_history_ret = {'val_accuracy': [loc_temp[1]], 'val_loss': [loc_temp[0]]}
    for _ in range(epochs):
        model.fit(loc_train_data, loc_train_labels, batch_size=batch_size, epochs=1, validation_data=(loc_test_data, loc_test_labels))
        if loc_conductances is not None:
            # Replace the kernels with device conductance values; biases stay as trained.
            for loc_layer in filter(lambda x: 'conv2d' in x.name, model.layers):
                loc_w = loc_layer.get_weights()
                loc_layer.set_weights([conductance_mapping(loc_w[0], loc_conductances, dimension=4), loc_w[1]])  # conv kernel and bias
            for loc_layer in filter(lambda x: 'dense' in x.name, model.layers):
                loc_w = loc_layer.get_weights()
                loc_layer.set_weights([conductance_mapping(loc_w[0], loc_conductances, dimension=2), loc_w[1]])  # dense kernel and bias
        loc_temp = model.evaluate(loc_test_data, loc_test_labels)
        loc_history_ret['val_accuracy'].append(loc_temp[1])
        loc_history_ret['val_loss'].append(loc_temp[0])
    # plotting
    loc_result_matrix = generate_result_matrix(model, loc_test_data, loc_test_labels)
    tools.matrixImaging(loc_result_matrix, scale=None, cmap='Blues', xlabel='Predicted labels',ylabel='Target labels', xticks=list(range(label_num)), yticks=list(range(label_num)), title = 'Confusion matrix')
    plt.figure(1)
    ax1 = plt.subplot(111)
    ax2 = ax1.twinx()
    ax1.plot(range(epochs + 1), loc_history_ret['val_accuracy'], 'r-o')
    ax1.set_ylim((0, 1))
    ax2.plot(range(epochs + 1), loc_history_ret['val_loss'], 'b-o')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Accuracy', color='r')
    ax2.set_ylabel('Loss', color='b')
    if plotTitle is not None:
        plt.title(plotTitle)
    if savefile is not None:
        np.savetxt(savefile+'_accuracy.txt', np.array([loc_history_ret['val_accuracy']]).T)
        np.savetxt(savefile+'_loss.txt', np.array([loc_history_ret['val_loss']]).T)
        np.savetxt(savefile+'_result_matrix.txt', loc_result_matrix)
        plt.savefig(savefile+'_epochs.jpg', dpi=300, bbox_inches = 'tight')
    plt.show()
    return loc_history_ret, loc_result_matrix

def convnet_model_training(model, train_dataset, test_dataset, imageDimension, label_num, batch_size=128, epochs=30, label_start=0):
    """Train an image model with full-precision weights and plot its progress.

    Each row of *train_dataset* / *test_dataset* is a flattened image
    followed by the class label in its last column; the image part is
    reshaped to *imageDimension*. The test set is evaluated before training,
    and that value is prepended to the per-epoch validation history.

    Args:
        model: compiled Keras model whose ``evaluate`` returns
            ``[loss, accuracy]``.
        train_dataset: 2-D array laid out as described above.
        test_dataset: 2-D array laid out as described above.
        imageDimension: per-sample shape, e.g. ``[height, width, channels]``
            (a list or a tuple).
        label_num: number of classes.
        batch_size: mini-batch size for ``model.fit``.
        epochs: number of training epochs.
        label_start: value of the first class label; it is subtracted from
            every label before one-hot encoding (e.g. 0 or 1).

    Returns:
        ``(history, result_matrix)``. *history* is the Keras ``History``
        object whose 'val_accuracy' / 'val_loss' lists have the pre-training
        value at index 0. *result_matrix* is the confusion matrix.
    """
    # dataset preprocessing; tuple() lets imageDimension be a list or a tuple
    loc_shape = tuple(imageDimension)
    loc_train_data = train_dataset[:, :-1].reshape((len(train_dataset),) + loc_shape)
    loc_test_data = test_dataset[:, :-1].reshape((len(test_dataset),) + loc_shape)
    # Shift labels so the first class lands in column 0. Previously only
    # label_start 0 or 1 worked; any other value raised UnboundLocalError.
    loc_train_labels = vectorize_labels(train_dataset[:, -1] - label_start, label_num)
    loc_test_labels = vectorize_labels(test_dataset[:, -1] - label_start, label_num)
    # model fitting
    loc_temp = model.evaluate(loc_test_data, loc_test_labels)
    loc_history = model.fit(loc_train_data, loc_train_labels, batch_size=batch_size, epochs=epochs, validation_data=(loc_test_data, loc_test_labels))
    # Prepend the pre-training evaluation so that index k corresponds to epoch k.
    loc_history.history['val_accuracy'].insert(0, loc_temp[1])
    loc_history.history['val_loss'].insert(0, loc_temp[0])
    # plotting
    loc_result_matrix = generate_result_matrix(model, loc_test_data, loc_test_labels)
    tools.matrixImaging(loc_result_matrix, scale=None, cmap='Blues', xlabel='Predicted labels',ylabel='Target labels', xticks=list(range(label_num)), yticks=list(range(label_num)), title = 'Confusion matrix')
    plt.figure(1)
    ax1 = plt.subplot(111)
    ax2 = ax1.twinx()
    ax1.plot(range(epochs + 1), loc_history.history['val_accuracy'], 'r-o')
    ax2.plot(range(epochs + 1), loc_history.history['val_loss'], 'b-o')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Accuracy', color='r')
    ax2.set_ylabel('Loss', color='b')
    plt.show()
    return loc_history, loc_result_matrix

def convnet_model_training_precision(model, train_dataset, test_dataset, imageDimension, label_num, batch_size=128, epochs=30, label_start=0, precision=8, plotTitle=None, savefile=None):
    """Train an image model one epoch at a time, snapping weights to device conductances.

    Each row of *train_dataset* / *test_dataset* is a flattened image
    followed by the class label in its last column; the image part is
    reshaped to *imageDimension*. The test set is evaluated once before
    training and again after every epoch. When ``precision != 0``, after
    each epoch the kernel of every layer whose name contains 'conv2d' or
    'dense' is replaced via ``conductance_mapping``, using the first column
    of ``'<precision> Bit_params.txt'``; biases are kept as trained.

    Args:
        model: compiled Keras model whose ``evaluate`` returns
            ``[loss, accuracy]``.
        train_dataset: 2-D array laid out as described above.
        test_dataset: 2-D array laid out as described above.
        imageDimension: per-sample shape, e.g. ``[height, width, channels]``
            (a list or a tuple).
        label_num: number of classes.
        batch_size: mini-batch size for ``model.fit``.
        epochs: number of training epochs.
        label_start: value of the first class label; it is subtracted from
            every label before one-hot encoding (e.g. 0 or 1).
        precision: bit precision selecting the conductance file; 0 disables
            the weight mapping.
        plotTitle: optional title of the accuracy/loss figure.
        savefile: optional path prefix; when given, the accuracy and loss
            curves, the confusion matrix and the figure are written to disk.

    Returns:
        ``(history, result_matrix)``. *history* maps 'val_accuracy' and
        'val_loss' to lists of ``epochs + 1`` values (index 0 = before
        training). *result_matrix* is the confusion matrix from
        ``generate_result_matrix``.
    """
    # dataset preprocessing; tuple() lets imageDimension be a list or a tuple
    loc_shape = tuple(imageDimension)
    loc_train_data = train_dataset[:, :-1].reshape((len(train_dataset),) + loc_shape)
    loc_test_data = test_dataset[:, :-1].reshape((len(test_dataset),) + loc_shape)
    # Shift labels so the first class lands in column 0. Previously only
    # label_start 0 or 1 worked; any other value raised UnboundLocalError.
    loc_train_labels = vectorize_labels(train_dataset[:, -1] - label_start, label_num)
    loc_test_labels = vectorize_labels(test_dataset[:, -1] - label_start, label_num)
    # The conductance table is the same for every epoch: read it once, up front.
    loc_conductances = None
    if precision != 0:
        loc_conductances = np.loadtxt(str(precision) + ' Bit_params.txt')[:, 0]
    # model fitting
    loc_temp = model.evaluate(loc_test_data, loc_test_labels)
    loc_history_ret = {'val_accuracy': [loc_temp[1]], 'val_loss': [loc_temp[0]]}
    for _ in range(epochs):
        model.fit(loc_train_data, loc_train_labels, batch_size=batch_size, epochs=1, validation_data=(loc_test_data, loc_test_labels))
        if loc_conductances is not None:
            # Replace the kernels with device conductance values; biases stay as trained.
            for loc_layer in filter(lambda x: 'conv2d' in x.name, model.layers):
                loc_w = loc_layer.get_weights()
                loc_layer.set_weights([conductance_mapping(loc_w[0], loc_conductances, dimension=4), loc_w[1]])  # conv kernel and bias
            for loc_layer in filter(lambda x: 'dense' in x.name, model.layers):
                loc_w = loc_layer.get_weights()
                loc_layer.set_weights([conductance_mapping(loc_w[0], loc_conductances, dimension=2), loc_w[1]])  # dense kernel and bias
        loc_temp = model.evaluate(loc_test_data, loc_test_labels)
        loc_history_ret['val_accuracy'].append(loc_temp[1])
        loc_history_ret['val_loss'].append(loc_temp[0])
    # plotting
    loc_result_matrix = generate_result_matrix(model, loc_test_data, loc_test_labels)
    tools.matrixImaging(loc_result_matrix, scale=None, cmap='Blues', xlabel='Predicted labels',ylabel='Target labels', xticks=list(range(label_num)), yticks=list(range(label_num)), title = 'Confusion matrix')
    plt.figure(1)
    ax1 = plt.subplot(111)
    ax2 = ax1.twinx()
    ax1.plot(range(epochs + 1), loc_history_ret['val_accuracy'], 'r-o')
    ax1.set_ylim((0, 1))
    ax2.plot(range(epochs + 1), loc_history_ret['val_loss'], 'b-o')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Accuracy', color='r')
    ax2.set_ylabel('Loss', color='b')
    if plotTitle is not None:
        plt.title(plotTitle)
    if savefile is not None:
        np.savetxt(savefile+'_accuracy.txt', np.array([loc_history_ret['val_accuracy']]).T)
        np.savetxt(savefile+'_loss.txt', np.array([loc_history_ret['val_loss']]).T)
        np.savetxt(savefile+'_result_matrix.txt', loc_result_matrix)
        plt.savefig(savefile+'_epochs.jpg', dpi=300, bbox_inches = 'tight')
    plt.show()
    return loc_history_ret, loc_result_matrix

def model_weight_change(model,conductances):
    """Quantize the kernels of a Keras model's Conv2D, SeparableConv2D and Dense layers in place.

    Every kernel is passed through ``conductance_mapping``, which rescales
    *conductances* to that kernel's own max |weight| and snaps each entry to
    the nearest signed level. Biases (when present) are copied unchanged.
    Layers are matched by exact class, so subclasses and other layer types
    are left untouched. A progress line is printed after every layer.

    NOTE(review): other weight-bearing layers (e.g. DepthwiseConv2D,
    BatchNormalization) are not quantized; confirm this is intended for the
    networks evaluated here.

    Args:
        model: Keras model; its layers are modified in place.
        conductances: 1-D array of device conductance values.

    Returns:
        The same *model* object.
    """
    loc_layer_num = len(model.layers)
    for loc_count, loc_layer in enumerate(model.layers, start=1):
        loc_type = type(loc_layer)
        # Public keras.layers classes replace the private keras.src.* paths,
        # which move between Keras releases.
        if loc_type is keras.layers.Conv2D:
            loc_w = loc_layer.get_weights()
            # kernel, followed by the bias when the layer has one
            loc_layer.set_weights([conductance_mapping(loc_w[0], conductances, dimension=4)] + loc_w[1:])
        elif loc_type is keras.layers.SeparableConv2D:
            loc_w = loc_layer.get_weights()
            # depthwise kernel and pointwise kernel are both quantized; a bias
            # (3rd weight, if present) is kept. Biased layers were previously skipped.
            loc_layer.set_weights([conductance_mapping(loc_w[0], conductances, dimension=4),
                                   conductance_mapping(loc_w[1], conductances, dimension=4)] + loc_w[2:])
        elif loc_type is keras.layers.Dense:
            loc_w = loc_layer.get_weights()
            # kernel, followed by the bias when the layer has one
            loc_layer.set_weights([conductance_mapping(loc_w[0], conductances, dimension=2)] + loc_w[1:])
        print("Done: %d//%d"%(loc_count,loc_layer_num))
    return model

def generate_result_matrix(model,test_data,test_labels):
    """Count (true class, predicted class) pairs into a confusion matrix.

    Args:
        model: object whose ``predict(test_data)`` returns one row of class
            scores per sample.
        test_data: model inputs; ``test_data.shape[0]`` samples are counted.
        test_labels: 2-D one-hot array; the first column equal to 1 in row k
            is the true class of sample k.

    Returns:
        float array of shape ``(n_classes, n_classes)``, where
        ``n_classes = test_labels.shape[1]``. Entry ``[t, p]`` counts the
        samples of true class ``t`` whose highest score (first one on ties)
        is class ``p``.
    """
    scores_all = model.predict(test_data)
    n_classes = test_labels.shape[1]
    confusion = np.zeros((n_classes, n_classes))
    for sample in range(test_data.shape[0]):
        scores = scores_all[sample]
        true_class = np.flatnonzero(test_labels[sample] == 1)[0]
        predicted_class = np.flatnonzero(scores == max(scores))[0]
        confusion[true_class, predicted_class] += 1
    return confusion

def preprocess_image(filepath, resize_short=256, crop_size=224):
    """Load an image, resize its shorter side, center-crop it and return it as RGB float32.

    Cropping recipe from
    https://blog.csdn.net/tjuyanming/article/details/105298043?utm_source=miniapp_weixin

    Steps:
        1. Read the file with OpenCV (BGR channel order).
        2. Resize with bicubic interpolation so the shorter side becomes
           *resize_short*, keeping the aspect ratio (integer division).
        3. Take the central ``crop_size x crop_size`` window.
        4. Reverse the channels to RGB.

    Args:
        filepath: path of the image file.
        resize_short: length of the shorter side after resizing (default 256).
        crop_size: side of the square center crop (default 224).

    Returns:
        float32 array of shape ``(crop_size, crop_size, 3)`` holding the raw
        pixel values in RGB order (no normalisation is applied).

    Raises:
        ValueError: if ``crop_size > resize_short``.
        FileNotFoundError: if OpenCV cannot read *filepath*.
    """
    if crop_size > resize_short:
        raise ValueError('crop_size (%d) must not exceed resize_short (%d)' % (crop_size, resize_short))

    # Load (as BGR); cv2.imread returns None instead of raising on failure.
    img = cv2.imread(filepath)
    if img is None:
        raise FileNotFoundError('cv2.imread could not read image: %s' % filepath)

    # Resize so that the shorter side equals resize_short.
    height, width = img.shape[:2]
    short_side = min(height, width)
    new_height = height * resize_short // short_side
    new_width = width * resize_short // short_side
    img = cv2.resize(img, (new_width, new_height), interpolation=cv2.INTER_CUBIC)

    # Center crop.
    height, width = img.shape[:2]
    top = height // 2 - crop_size // 2
    left = width // 2 - crop_size // 2
    img = img[top:top + crop_size, left:left + crop_size]
    if img.shape[:2] != (crop_size, crop_size):  # sanity check (replaces an assert stripped under -O)
        raise ValueError('unexpected crop shape %s from resized image %dx%d' % (img.shape, height, width))

    # BGR -> RGB, as a contiguous float32 array.
    return np.ascontiguousarray(img[:, :, ::-1], dtype=np.float32)



if __name__ == "__main__":
    # Experiment driver: uncomment the section to run. Only the Xception
    # ImageNet evaluation is currently active.
    work_directory = 'RawData/'
    
    # # 9 Bit
    # header = "9 Bit"
    # factor = 1.324
    # data,params = state_extraction.get_data(work_directory)
    # rdata,rparams = state_extraction.refine_data(data,params,factor)
    # state_extraction.draw_data(rdata,rparams,title=header)
    # states = np.array(list(map(lambda x: [x['mean'],x['standard deviation']], rparams)))
    
    # # training process of a full-connected network with MNIST dataset.
    # for precision in range(0,11):
    #     # precision=8
    #     label_num = 10 # class number of the dataset, maximum is 10.
    #     (train_data,train_labels),(test_data,test_labels) = mnist.load_data()
    #     train_dataset = np.hstack(((train_data/255).reshape((train_data.shape[0],784)),np.array([train_labels]).T))
    #     test_dataset = np.hstack(((test_data/255).reshape((test_data.shape[0],784)),np.array([test_labels]).T))
    #     model = build_model_net(label_num,input_size=784,lr=1e-4)
    #     history = net_model_training_precision(model, train_dataset, test_dataset,  label_num, epochs=200, precision=precision, plotTitle=str(precision)+' Bit precision', savefile=str(precision)+' Bit_MNIST')
    
    # # training process of a convolutional network with MNIST dataset.
    # precision=4
    # label_num = 10 # class number of the dataset, maximum is 10.
    # imageDimension = [28,28,1]
    # (train_data,train_labels),(test_data,test_labels) = mnist.load_data()    
    # train_dataset = np.hstack(((train_data/255).reshape((train_data.shape[0],784)),np.array([train_labels]).T))
    # test_dataset = np.hstack(((test_data/255).reshape((test_data.shape[0],784)),np.array([test_labels]).T))
    # model = build_convnet_model(label_num,imageDimension,kernerNum=1, lr=1e-3) # construct the network
    # history = convnet_model_training_precision(model, train_dataset, test_dataset, imageDimension, label_num, epochs=5, precision=precision, plotTitle=str(precision)+' Bit precision', savefile=str(precision)+' Bit_MNIST')
    
    # # training process of a convolutional network with CIFAR-10 dataset.
    # for precision in range(0,11):
    #     # precision=1
    #     label_num = 10 # class number of the dataset, maximum is 10.
    #     imageDimension = [32,32,3]
    #     (train_data,train_labels),(test_data,test_labels) = cifar10.load_data()    
    #     train_dataset = np.hstack(((train_data/255).reshape((train_data.shape[0],32*32*3)),train_labels))
    #     test_dataset = np.hstack(((test_data/255).reshape((test_data.shape[0],32*32*3)),test_labels))
    #     model = build_convnet_model(label_num,imageDimension,kernerNum=32, lr=1e-3) # construct the network
    #     history = convnet_model_training_precision(model, train_dataset, test_dataset, imageDimension, label_num, epochs=50, precision=precision, plotTitle=str(precision)+' Bit precision', savefile=str(precision)+' Bit_CIFAR-10')
    
    # # replace the network weights with device parameters
    # for precision in [4,5,6,7,8,9]:
    #     # precision = 2
    #     # netname = 'xception'
    #     netname = 'efficientnet'
    #     model=EfficientNetB0(weights='imagenet')
    #     conductances = np.loadtxt(str(precision)+' Bit_params.txt')[:,0]
    #     model = model_weight_change(model,conductances)
    #     model.save(netname+'_'+str(precision)+' Bit.h5', overwrite=True, include_optimizer=True)
    
    # # evaluation of the resnet model
    # netname = 'resnet50'
    # for precision in ['unlimited',9,8,7,6,5,4]:
    #     # precision = 2
    #     # precision = 'unlimited'
    #     sample_num, batch_size = 50000, 500
    #     if precision == 'unlimited':
    #         model=ResNet50(weights='imagenet')
    #     else:
    #         model=models.load_model(netname+'_'+str(precision)+' Bit.h5', custom_objects=None, compile=True)
    #     # model=MobileNet(weights='imagenet')
    #     synsets = list(map(lambda x: x[0][1][0],scio.loadmat('E:/Datasets/ImageNet/ILSVRC2012_devkit_t12/data/meta.mat')['synsets'])) # index+1 is the ILSVRC2012_ID, the value is the WNID
    #     val_truth = list(map(lambda x: int(x), np.loadtxt('E:/Datasets/ImageNet/ILSVRC2012_devkit_t12/data/ILSVRC2012_validation_ground_truth.txt'))) # index+1 is the validation image number, the value is the ILSVRC2012_ID
    #     start_time = time.time()
    #     correct_num_top5 = 0
    #     correct_num_top1 = 0
    #     for j in range(int(sample_num/batch_size)):
    #         results = []
    #         for i in range(j*batch_size+1,(j+1)*batch_size+1):
    #             img_path = 'E:/Datasets/ImageNet/ILSVRC2012_img_val/ILSVRC2012_val_000'+('%05d'%i)+'.JPEG'
    
    #             x=preprocess_image(img_path) # read and crop the image, see https://blog.csdn.net/tjuyanming/article/details/105298043?utm_source=miniapp_weixin
    #             # img = keras.utils.load_img(img_path, target_size=(224, 224))
    #             # x = keras.utils.img_to_array(img)
    
    #             x = np.expand_dims(x, axis=0)
    #             x = preprocess_input_resnet50(x) # change according to the network type
    #             predicts = model.predict(x,verbose=0)
    #             wnids = list(map(lambda x: x[0], decode_predictions(predicts)[0])) # get the top-5 WNIDs
    #             if synsets[val_truth[i-1]-1] in wnids:
    #                 correct_num_top5 += 1
    #             if synsets[val_truth[i-1]-1] == wnids[0]:
    #                 correct_num_top1 += 1
    #             results.append(wnids)
    #         with open('ImageNet_val_top5_WNIDs_'+netname+'_'+str(precision)+' Bit.txt','a+') as f:
    #             np.savetxt(f,np.array(results),fmt='%s')
    #         end_time = time.time()
    #         print('Time elapsed: %.2f s\t Samples consumed: %d//%d' % (end_time-start_time,i,sample_num))
    #     os.renames('ImageNet_val_top5_WNIDs_'+netname+'_'+str(precision)+' Bit.txt', 'ImageNet_val_top5_WNIDs_'+netname+'_'+str(precision)+' Bit_top5-accuracy-%.5f_top1-accuracy-%.5f.txt'%(correct_num_top5/sample_num,correct_num_top1/sample_num))

    # # evaluation of the mobilenet model
    # netname = 'mobilenet'
    # for precision in ['unlimited',9,8,7,6,5,4]:
    #     # precision = 2
    #     # precision = 'unlimited'
    #     sample_num, batch_size = 50000, 500
    #     # sample_num, batch_size = 100, 10
    #     if precision == 'unlimited':
    #         model=MobileNet(weights='imagenet')
    #     else:
    #         model=models.load_model(netname+'_'+str(precision)+' Bit.h5', custom_objects=None, compile=True)
    #     # model=MobileNet(weights='imagenet')
    #     synsets = list(map(lambda x: x[0][1][0],scio.loadmat('E:/Datasets/ImageNet/ILSVRC2012_devkit_t12/data/meta.mat')['synsets'])) # index+1 is the ILSVRC2012_ID, the value is the WNID
    #     val_truth = list(map(lambda x: int(x), np.loadtxt('E:/Datasets/ImageNet/ILSVRC2012_devkit_t12/data/ILSVRC2012_validation_ground_truth.txt'))) # index+1 is the validation image number, the value is the ILSVRC2012_ID
    #     start_time = time.time()
    #     correct_num_top5 = 0
    #     correct_num_top1 = 0
    #     for j in range(int(sample_num/batch_size)):
    #         results = []
    #         for i in range(j*batch_size+1,(j+1)*batch_size+1):
    #             img_path = 'E:/Datasets/ImageNet/ILSVRC2012_img_val/ILSVRC2012_val_000'+('%05d'%i)+'.JPEG'
    
    #             x=preprocess_image(img_path) # read and crop the image, see https://blog.csdn.net/tjuyanming/article/details/105298043?utm_source=miniapp_weixin
    #             # img = keras.utils.load_img(img_path, target_size=(224, 224))
    #             # x = keras.utils.img_to_array(img)
    
    #             x = np.expand_dims(x, axis=0)
    #             x = preprocess_input_mobilenet(x) # change according to the network type
    #             predicts = model.predict(x,verbose=0)
    #             wnids = list(map(lambda x: x[0], decode_predictions(predicts)[0])) # get the top-5 WNIDs
    #             if synsets[val_truth[i-1]-1] in wnids:
    #                 correct_num_top5 += 1
    #             if synsets[val_truth[i-1]-1] == wnids[0]:
    #                 correct_num_top1 += 1
    #             results.append(wnids)
    #         with open('ImageNet_val_top5_WNIDs_'+netname+'_'+str(precision)+' Bit.txt','a+') as f:
    #             np.savetxt(f,np.array(results),fmt='%s')
    #         end_time = time.time()
    #         print('Time elapsed: %.2f s\t Samples consumed: %d//%d' % (end_time-start_time,i,sample_num))
    #     os.renames('ImageNet_val_top5_WNIDs_'+netname+'_'+str(precision)+' Bit.txt', 'ImageNet_val_top5_WNIDs_'+netname+'_'+str(precision)+' Bit_top5-accuracy-%.5f_top1-accuracy-%.5f.txt'%(correct_num_top5/sample_num,correct_num_top1/sample_num))

    # evaluation of the xception model
    netname = 'xception'
    sample_num, batch_size = 50000, 500
    # The ground-truth tables do not depend on the precision, so load them once.
    # synsets[k] is the WNID of ILSVRC2012_ID k+1.
    synsets = list(map(lambda x: x[0][1][0],scio.loadmat('E:/Datasets/ImageNet/ILSVRC2012_devkit_t12/data/meta.mat')['synsets']))
    # val_truth[k] is the ILSVRC2012_ID of validation image number k+1.
    val_truth = list(map(lambda x: int(x), np.loadtxt('E:/Datasets/ImageNet/ILSVRC2012_devkit_t12/data/ILSVRC2012_validation_ground_truth.txt')))
    for precision in ['unlimited',9,8,7,6,5,4]:
        # precision = 2
        # precision = 'unlimited'
        if precision == 'unlimited':
            model=Xception(weights='imagenet')
        else:
            model=models.load_model(netname+'_'+str(precision)+' Bit.h5', custom_objects=None, compile=True)
        result_file = 'ImageNet_val_top5_WNIDs_'+netname+'_'+str(precision)+' Bit.txt'
        # Results are appended batch by batch below. Remove a partial file left
        # by an interrupted run so its rows are not mixed into this run.
        if os.path.exists(result_file):
            os.remove(result_file)
        num_batches = sample_num // batch_size
        evaluated_num = num_batches * batch_size  # images actually scored
        start_time = time.time()
        correct_num_top5 = 0
        correct_num_top1 = 0
        for j in range(num_batches):
            results = []
            for i in range(j*batch_size+1,(j+1)*batch_size+1):
                img_path = 'E:/Datasets/ImageNet/ILSVRC2012_img_val/ILSVRC2012_val_000'+('%05d'%i)+'.JPEG'

                # x=preprocess_image(img_path) # read and crop the image, see https://blog.csdn.net/tjuyanming/article/details/105298043?utm_source=miniapp_weixin
                img = keras.utils.load_img(img_path, target_size=(299, 299))
                x = keras.utils.img_to_array(img)

                x = np.expand_dims(x, axis=0)
                x = preprocess_input_xception(x) # must match the network type
                predicts = model.predict(x,verbose=0)
                wnids = [entry[0] for entry in decode_predictions(predicts)[0]] # top-5 WNIDs
                truth = synsets[val_truth[i-1]-1]
                if truth in wnids:
                    correct_num_top5 += 1
                if truth == wnids[0]:
                    correct_num_top1 += 1
                results.append(wnids)
            with open(result_file,'a+') as f:
                np.savetxt(f,np.array(results),fmt='%s')
            end_time = time.time()
            print('Time elapsed: %.2f s\t Samples consumed: %d//%d' % (end_time-start_time,i,sample_num))
        # Accuracy is computed over the images actually evaluated.
        os.renames(result_file, 'ImageNet_val_top5_WNIDs_'+netname+'_'+str(precision)+' Bit_top5-accuracy-%.5f_top1-accuracy-%.5f.txt'%(correct_num_top5/evaluated_num,correct_num_top1/evaluated_num))

    # # evaluation of the EfficientNet model
    # netname = 'efficientnet'
    # for precision in ['unlimited',9,8,7,6,5,4]:
    #     # precision = 2
    #     # precision = 'unlimited'
    #     sample_num, batch_size = 50000, 500
    #     if precision == 'unlimited':
    #         model=EfficientNetB0(weights='imagenet')
    #     else:
    #         model=models.load_model(netname+'_'+str(precision)+' Bit.h5', custom_objects=None, compile=True)
    #     # model=MobileNet(weights='imagenet')
    #     synsets = list(map(lambda x: x[0][1][0],scio.loadmat('E:/Datasets/ImageNet/ILSVRC2012_devkit_t12/data/meta.mat')['synsets'])) # index+1 is the ILSVRC2012_ID, the value is the WNID
    #     val_truth = list(map(lambda x: int(x), np.loadtxt('E:/Datasets/ImageNet/ILSVRC2012_devkit_t12/data/ILSVRC2012_validation_ground_truth.txt'))) # index+1 is the validation image number, the value is the ILSVRC2012_ID
    #     start_time = time.time()
    #     correct_num_top5 = 0
    #     correct_num_top1 = 0
    #     for j in range(int(sample_num/batch_size)):
    #         results = []
    #         for i in range(j*batch_size+1,(j+1)*batch_size+1):
    #             img_path = 'E:/Datasets/ImageNet/ILSVRC2012_img_val/ILSVRC2012_val_000'+('%05d'%i)+'.JPEG'
    
    #             x=preprocess_image(img_path) # read and crop the image, see https://blog.csdn.net/tjuyanming/article/details/105298043?utm_source=miniapp_weixin
    #             # img = keras.utils.load_img(img_path, target_size=(300, 300))
    #             # x = keras.utils.img_to_array(img)
    
    #             x = np.expand_dims(x, axis=0)
    #             x = preprocess_input_efficientnet(x) # change according to the network type
    #             predicts = model.predict(x,verbose=0)
    #             wnids = list(map(lambda x: x[0], decode_predictions(predicts)[0])) # get the top-5 WNIDs
    #             if synsets[val_truth[i-1]-1] in wnids:
    #                 correct_num_top5 += 1
    #             if synsets[val_truth[i-1]-1] == wnids[0]:
    #                 correct_num_top1 += 1
    #             results.append(wnids)
    #         with open('ImageNet_val_top5_WNIDs_'+netname+'_'+str(precision)+' Bit.txt','a+') as f:
    #             np.savetxt(f,np.array(results),fmt='%s')
    #         end_time = time.time()
    #         print('Time elapsed: %.2f s\t Samples consumed: %d//%d' % (end_time-start_time,i,sample_num))
    #     os.renames('ImageNet_val_top5_WNIDs_'+netname+'_'+str(precision)+' Bit.txt', 'ImageNet_val_top5_WNIDs_'+netname+'_'+str(precision)+' Bit_top5-accuracy-%.5f_top1-accuracy-%.5f.txt'%(correct_num_top5/sample_num,correct_num_top1/sample_num))
